import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as grad

from .transformer import Transformer

class PointNetBase(nn.Module):
    """Shared PointNet feature extractor.

    Aligns the input points with a K x K matrix predicted by an input
    ``Transformer``, lifts them to 64-d per-point features, aligns those with
    a 64 x 64 matrix from a second ``Transformer``, lifts again to 1024-d and
    max-pools over the point axis to obtain a global descriptor.

    Args:
        num_points: forwarded to both ``Transformer`` sub-networks only;
            ``forward`` pools over the actual point count of its input.
            (Presumably the expected number of points — confirm against
            ``Transformer``.)
        K: number of input channels per point (default 3).
    """

    def __init__(self, num_points=2000, K=3):
        super().__init__()
        # Construction order (transformers first, then the MLPs) is kept so
        # parameter registration and random initialisation order are unchanged.
        self.input_transformer = Transformer(num_points, K)
        self.embedding_transformer = Transformer(num_points, 64)
        # Kernel-size-1 convolutions apply the same MLP independently to
        # every point.
        self.mlp1 = self._pointwise_mlp(K, 64, 64)
        self.mlp2 = self._pointwise_mlp(64, 64, 128, 1024)

    @staticmethod
    def _pointwise_mlp(*widths):
        """Build Conv1d(k=1) -> BatchNorm1d -> ReLU for each consecutive width pair."""
        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [nn.Conv1d(c_in, c_out, 1), nn.BatchNorm1d(c_out), nn.ReLU()]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Compute global and per-point features for a batch of point sets.

        Args:
            x: tensor of shape (B, K, N) — batch, channels, points.

        Returns:
            Tuple ``(global_feature, local_embedding, T2)`` where
            ``global_feature`` is (B, 1024), ``local_embedding`` is the
            aligned (B, 64, N) per-point feature map, and ``T2`` is the
            (B, 64, 64) feature-space transform. The input transform is not
            returned.
        """
        num_pts = x.shape[2]
        aligned = torch.bmm(self.input_transformer(x), x)
        features = self.mlp1(aligned)
        T2 = self.embedding_transformer(features)
        local_embedding = torch.bmm(T2, features)
        # A pooling window spanning all N points reduces the point axis to
        # length 1, giving an order-invariant max over points.
        pooled = F.max_pool1d(self.mlp2(local_embedding), kernel_size=num_pts)
        return pooled.squeeze(2), local_embedding, T2

